# !python3 -m pip install imageio
from datetime import datetime
from IPython.display import Image
import matplotlib.pyplot as plt
import numpy as np
from imageio import imread
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm
The aligned CelebA images have shape (218, 178, 3); there are 202,599 of them.
n_images = 202599
start = np.random.randint(1, n_images-10)
FOLDER = 'img_align_celeba/0/'
FILES = ["{}{:06}.jpg".format(FOLDER, i) for i in range(start+1, start+11)]
scale = 0.8
fig, axs = plt.subplots(2, 5, figsize=(scale*20,scale*10))
for i, ax in enumerate(axs.flatten()):
    img = imread(FILES[i])
    ax.imshow(img)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect('equal')
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()
def display_x_tensor(t, title=None):
npimg = t.cpu().numpy().clip(0, 1)
plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
if title:
plt.title(title)
plt.show()
def generate_z(n=1):
return torch.tensor(np.random.uniform(low=-1, high=1, size=(n,100))).float()
class CelebDataset(Dataset):
def __init__(self, folder='img_align_celeba/'):
self.dataset_folder = datasets.ImageFolder(folder, transform=transforms.Compose([
transforms.Resize((128,128)),
transforms.ToTensor()
]))
def __getitem__(self,index):
"""
Returns:
x (3, 512, 512) - float in [0,1).
y (512, 512) - byte in {0,1}.
"""
img = self.dataset_folder[index][0] # (3, 218, 178) torch.FloatTensor with values in [0, 1]
return img
def __len__(self):
        assert len(self.dataset_folder) == n_images, "Dataset size mismatch: {} != {}".format(
            len(self.dataset_folder), n_images)
return len(self.dataset_folder)
celeb_dataset = CelebDataset()
n_dataset = len(celeb_dataset)
img = celeb_dataset[np.random.randint(n_dataset)]
display_x_tensor(img)
def generator_overfit(n_epochs=101, n_image_every=100, lr=0.001):
gen = Generator()
y_truth = celeb_dataset[0]
z_const = generate_z(n=1)
objective = nn.MSELoss()
optimizer = optim.Adam(gen.parameters(), lr=lr)
for i in range(n_epochs):
# Clear the gradients
optimizer.zero_grad()
# Generate an image
y_hat = gen(z_const)
# Calculate the loss and gradients
        loss = objective(y_hat, y_truth.unsqueeze(0))
loss.backward()
optimizer.step()
if i%n_image_every == 0:
print("loss: {:.3f}, max: {:.3f}, min: {:.3f}".format(loss.item(), y_hat[0].max(), y_hat[0].min()))
display_x_tensor(y_hat[0].detach())
# generator_overfit()
def discriminator_overfit(n_epochs=101, n_print_every=100, lr=0.001):
dis = Discriminator()
optimizer = optim.Adam(dis.parameters(), lr=lr)
for i in range(n_epochs):
# Clear gradients
optimizer.zero_grad()
# Forward pass
        x = celeb_dataset[0].unsqueeze(0)
        y_hat = dis(x)
        # Calculate loss and gradients. The critic has no sigmoid and returns an
        # unbounded score; following the sign convention used in training below
        # (real images get negative scores), the loss here is simply the raw score.
        loss = y_hat.mean()
loss.backward()
# Update weights
optimizer.step()
# Print progress
if i%n_print_every == 0:
print("loss: {:.3f}, max: {:.3f}, min: {:.3f}".format(loss.item(), y_hat[0].max(), y_hat[0].min()))
dis(x)
# discriminator_overfit()
How is a transposed convolution of size 5 supposed to generate an even-sized output from an even-sized input?
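For nn.ConvTranspose2d the output size is out = (in - 1)*stride - 2*padding + kernel_size + output_padding. A quick sketch (the conv_transpose_out helper below is written here purely for illustration) of why kernel size 4, stride 2, padding 1 exactly doubles an even input, while an odd kernel such as 5 with stride 2 always produces an odd output:

def conv_transpose_out(size, kernel_size, stride, padding, output_padding=0):
    # PyTorch's output-size formula for ConvTranspose2d (dilation = 1)
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding

print(conv_transpose_out(4, kernel_size=4, stride=2, padding=1))  # 8: even in, even out (doubled)
print(conv_transpose_out(4, kernel_size=5, stride=2, padding=2))  # 7: even in, odd out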

class Generator(nn.Module):
def __init__(self):
super().__init__()
kernel_size_c = 3
stride_c = 1
padding_c = 1
kernel_size = 4
stride = 2
padding = 1
# starting size: (n, 100,)
self.project = nn.Linear(100, 4*4*1024)
        # post_projection size: (n, 1024, 4, 4)
self.conv1 = nn.Sequential(
# nn.Conv2d(1024, 512, kernel_size_c, stride_c, padding_c),
nn.ConvTranspose2d(in_channels=1024, out_channels=512,
kernel_size=kernel_size, stride=stride, padding=padding),
nn.ReLU(),
nn.BatchNorm2d(num_features=512)
)
# post_conv1 size: (n, 512, 8, 8)
self.conv2 = nn.Sequential(
# nn.Conv2d(512, 256, kernel_size_c, stride_c, padding_c),
nn.ConvTranspose2d(512, 256, kernel_size, stride, padding),
nn.ReLU(),
nn.BatchNorm2d(256)
)
# post_conv2 size: (256, 16, 16)
self.conv3 = nn.Sequential(
# nn.Conv2d(256, 128, kernel_size_c, stride_c, padding_c),
nn.ConvTranspose2d(256, 128, kernel_size, stride, padding),
nn.ReLU(),
nn.BatchNorm2d(128)
)
# post_conv3 size: (128, 32, 32)
self.conv4 = nn.Sequential(
# nn.Conv2d(128, 64, kernel_size_c, stride_c, padding_c),
nn.ConvTranspose2d(128, 64, kernel_size, stride, padding),
nn.ReLU(),
nn.BatchNorm2d(64)
)
# post_conv4 size: (64, 64, 64)
# then upsample to remove checkerboard artifacts
# size: (64, 128, 128)
self.conv5 = nn.Sequential(
nn.Conv2d(64, 3, kernel_size_c, stride_c, padding_c),
# nn.ConvTranspose2d(64, 3, kernel_size, stride, padding),
nn.Sigmoid()
)
# post_conv5 size: (3, 128, 128)
def forward(self, z):
z1 = self.project(z).view(-1, 1024, 4, 4)
# z1 = F.interpolate(z1, scale_factor=2)
z2 = self.conv1(z1)
# z2 = F.interpolate(z2, scale_factor=2)
z3 = self.conv2(z2)
# z3 = F.interpolate(z3, scale_factor=2)
z4 = self.conv3(z3)
# z4 = F.interpolate(z4, scale_factor=2)
z5 = self.conv4(z4)
z5 = F.interpolate(z5, scale_factor=2)
z6 = self.conv5(z5)
return z6
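A quick shape sanity check (a throwaway sketch, not part of training; gen_test is just a scratch name): the generator should map 100-dimensional z vectors to (n, 3, 128, 128) images, and the final sigmoid keeps values in [0, 1].

gen_test = Generator()
with torch.no_grad():
    out = gen_test(generate_z(n=2))
print(out.shape)                           # expected: torch.Size([2, 3, 128, 128])
print(out.min().item(), out.max().item())  # expected: within [0, 1]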
generator_overfit()

leaky_slope = 0.2
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
kernel_size = 4
stride = 2
padding = 1
# input shape (n, 3, 128, 128)
self.pre_conv1 = nn.Sequential(
nn.Conv2d(3, 32, kernel_size, stride, padding, bias=False),
nn.LeakyReLU(leaky_slope),
nn.BatchNorm2d(32)
)
# shape (n, 32, 64, 64)
self.conv1 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size, stride, padding, bias=False),
nn.LeakyReLU(leaky_slope),
nn.BatchNorm2d(64)
)
# post_conv1 shape (n, 64, 32, 32)
self.conv2 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size, stride, padding, bias=False),
nn.LeakyReLU(leaky_slope),
nn.BatchNorm2d(128)
)
# post_conv2 shape (n, 128, 16, 16)
self.conv3 = nn.Sequential(
nn.Conv2d(128, 256, kernel_size, stride, padding, bias=False),
nn.LeakyReLU(leaky_slope),
nn.BatchNorm2d(256)
)
# post_conv3 shape (n, 256, 8, 8)
self.conv4 = nn.Sequential(
nn.Conv2d(256, 512, kernel_size, stride, padding, bias=False),
nn.LeakyReLU(leaky_slope),
nn.BatchNorm2d(512)
)
# post_conv4 shape (n, 512, 4, 4)
# self.conv5 = nn.Sequential(
# nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0, bias=False),
# nn.LeakyReLU(leaky_slope)
# )
# post_conv5 shape (n, 1, 4, 4)
self.dense = nn.Sequential(
            nn.Linear(512*4*4, 1)  # WGAN critic: no sigmoid; it returns an unbounded score, not a probability.
# nn.Sigmoid()
)
# post_dense shape (n, 1)
def forward(self, x):
x1 = self.pre_conv1(x)
x2 = self.conv1(x1)
x3 = self.conv2(x2)
x4 = self.conv3(x3)
x5 = self.conv4(x4)
# x6 = self.conv5(x5)
out = self.dense(x5.view(-1, 512*4*4))
return out
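Likewise, a quick sketch (dis_test is a scratch name) to confirm the critic maps a batch of 128x128 RGB images to one unbounded score per image:

dis_test = Discriminator()
with torch.no_grad():
    scores = dis_test(torch.rand(2, 3, 128, 128))
print(scores.shape)  # expected: torch.Size([2, 1]); raw scores, no sigmoid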
discriminator_overfit()
Save the model weights, a grid of generated samples, and an interpolation strip each epoch.
This allows training to be resumed and results to be compared across epochs.
def model_checkpoint(CHECKPOINT_FOLDER='checkpoints/test/', epoch="?"):
# Filenames
FILE_FORMAT = '{time}_epoch{epoch}_{which}.weights'
IMAGE_FORMAT = '{time}_epoch{epoch}_generated.jpg'
INTERPOLATED_FORMAT = '{time}_epoch{epoch}_interpolated.jpg'
strftime = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
filename_gen = FILE_FORMAT.format(epoch=epoch, which='gen', time=strftime)
filename_dis = FILE_FORMAT.format(epoch=epoch, which='dis', time=strftime)
filename_img = IMAGE_FORMAT.format(epoch=epoch, time=strftime)
filename_interpolated = INTERPOLATED_FORMAT.format(epoch=epoch, time=strftime)
# Create necessary folder
if not os.path.exists(CHECKPOINT_FOLDER):
os.makedirs(CHECKPOINT_FOLDER)
# More filenames
GEN_PATH = CHECKPOINT_FOLDER+filename_gen
DIS_PATH = CHECKPOINT_FOLDER+filename_dis
IMG_PATH = CHECKPOINT_FOLDER+filename_img
INTERPOLATED_PATH = CHECKPOINT_FOLDER + filename_interpolated
CURRENT_GEN_PATH = CHECKPOINT_FOLDER + 'current_gen.weights'
CURRENT_DIS_PATH = CHECKPOINT_FOLDER + 'current_dis.weights'
CURRENT_IMG_PATH = CHECKPOINT_FOLDER + 'current_generated.jpg'
CURRENT_INTERPOLATED_PATH = CHECKPOINT_FOLDER + 'current_interpolated.jpg'
# Save the model weights
torch.save(gen.state_dict(), GEN_PATH)
torch.save(gen.state_dict(), CURRENT_GEN_PATH)
torch.save(dis.state_dict(), DIS_PATH)
torch.save(dis.state_dict(), CURRENT_DIS_PATH)
# Save a 4x4 grid of generated images
img_grid = gen(generate_z(n=16).cuda()).detach()
save_image(img_grid, IMG_PATH, nrow=4)
save_image(img_grid, CURRENT_IMG_PATH, nrow=4)
# Interpolate at 8 points
z1, z2 = generate_z(n=1), generate_z(n=1)
epsilons = np.linspace(0, 1, 8)
zs = [e*z1 + (1-e)*z2 for e in epsilons]
z = torch.stack(zs, dim=0).squeeze(1)
interpolation_img = gen(z.cuda()).detach()
save_image(interpolation_img, INTERPOLATED_PATH)
save_image(interpolation_img, CURRENT_INTERPOLATED_PATH)
Create or reload the model. Train, saving at specified checkpoints. Save model outputs into a new folder each time.
gen = nn.DataParallel(Generator().cuda())
dis = nn.DataParallel(Discriminator().cuda())
if False:
GEN_PATH = "checkpoints/GPv2-Oct27/2018-10-27_15:00:52_epoch4_gen.weights"
DIS_PATH = "checkpoints/GPv2-Oct27/2018-10-27_15:00:52_epoch4_dis.weights"
gen.load_state_dict(torch.load(GEN_PATH))
dis.load_state_dict(torch.load(DIS_PATH))
CHECKPOINT_FOLDER = 'checkpoints/Oct30-128px-iter3-twogpus/'
batch_size = 128
n_batches = n_images // batch_size # 1582 with batch_size=128
n_epochs = 25
n_batches_per_epoch = n_batches
n_image_every = 1000
n_critic = 5
lr_gen = 0.001
lr_dis = 0.00001
beta1 = 0.5 # Adam momentum term
lmbda = 10 # gradient penalty term
c = 0.05 # discriminator weight clamping
# objective = nn.BCELoss()
objective = lambda score, label: (score * label).mean()
optimizer_gen = optim.Adam(gen.parameters(), lr=lr_gen, betas=(beta1, 0.999))
optimizer_dis = optim.Adam(dis.parameters(), lr=lr_dis, betas=(beta1, 0.999))
# Since our loss is label*score, this trains our model to give real images a negative score
# and fake images a positive score. It probably makes more sense to give real images a positive score,
# but we'll shelve that for now. https://github.com/keras-team/keras-contrib/issues/280
real_label = 1
fake_label = -1
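As a tiny worked example of this convention (illustrative numbers only): with objective = (score * label).mean(), a real image (label +1) adds its raw score to the loss, so minimizing drives real scores negative, while a fake (label -1) subtracts its score, driving fake scores positive.

example_scores = torch.tensor([[-2.0], [3.0]])  # critic scores: first image scored "real-ish", second "fake-ish"
example_labels = torch.tensor([[1.0], [-1.0]])  # +1 = real, -1 = fake
print(objective(example_scores, example_labels).item())  # (-2.0*1 + 3.0*(-1)) / 2 = -2.5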
celeb_loader = DataLoader(
dataset=celeb_dataset,
batch_size=batch_size,
pin_memory=True,
shuffle=True
)
print("Number of Batches Per Epoch: {}/{}".format(n_batches_per_epoch, n_batches))
for i_epoch in range(n_epochs):
print("#############")
print("## Epoch {} #".format(i_epoch))
print("#############")
# Save intermediate progress
model_checkpoint(CHECKPOINT_FOLDER, i_epoch)
# Train
for i_batch, x_truth in enumerate(celeb_loader):
x_truth = x_truth.cuda(non_blocking=True)
batch_size_actual = x_truth.shape[0]
####################
# Update D network #
####################
dis.zero_grad()
# train with real
        label = torch.full((batch_size_actual, 1), real_label, dtype=torch.float).cuda()
output = dis(x_truth)
err_dis_real = objective(output, label)
err_dis_real.backward()
D_x = output.mean().item()
# train with fake
z = generate_z(n=batch_size_actual)
x_fake = gen(z.cuda(non_blocking=True))
label.fill_(fake_label)
output = dis(x_fake.detach())
err_dis_fake = objective(output, label)
err_dis_fake.backward()
D_xfake_1 = output.mean().item()
err_dis = err_dis_real + err_dis_fake
optimizer_dis.step()
# for p in dis.parameters():
# p.data.clamp_(-c, c)
# gradient penalty
# see https://discuss.pytorch.org/t/gradient-penalty-with-respect-to-the-network-parameters/11944/2
        # Per-sample mixing coefficients, as in the WGAN-GP formulation
        epsilon = torch.rand(batch_size_actual, 1, 1, 1).cuda()
        x_middle = epsilon*x_truth + (1-epsilon)*x_fake
        output = dis(x_middle).mean()
grad_wrt_x = torch.autograd.grad(output, x_middle, create_graph=True)[0]
        loss_gp = lmbda * ((grad_wrt_x.view(batch_size_actual, -1).norm(dim=1) - 1)**2).mean()
loss_gp.backward()
err_gp = loss_gp.item()
optimizer_dis.step()
if i_batch % n_critic == 0:
####################
# Update G network #
####################
gen.zero_grad()
            z = generate_z(n=batch_size_actual)
x_fake = gen(z.cuda())
label.fill_(real_label) # reverse the discriminator
output = dis(x_fake)
            # With the linear objective and label = +1, minimizing (score * label).mean()
            # pushes the critic's score on fakes toward negative values, i.e. toward "real".
err_gen = objective(output, label)
err_gen.backward()
# Generator seeks negative output (meaning it's viewed as real)
D_xfake_2 = output.mean().item()
optimizer_gen.step()
# Display intermediate progress
if i_batch % n_image_every == 0:
grad_gen_norm = np.mean([param.grad.norm().item() for param in gen.parameters()])
grad_dis_norm = np.mean([param.grad.norm().item() for param in dis.parameters()])
print("Epoch {}, Batch {}".format(i_epoch, i_batch))
print("Grad_D: {:.3f}, Grad_G: {:.3f}".format(grad_dis_norm, grad_gen_norm))
print("Loss_D: {:.3f}, Loss_G: {:.3f}, Loss_GP: {:.3f}, D(real): {:.3f}, D(fake): {:.3f}/{:.3f}".format(
err_dis, err_gen, err_gp, D_x, D_xfake_1, D_xfake_2
))
display_x_tensor(x_truth[0].detach())
display_x_tensor(x_fake[0].detach())
if i_batch == n_batches_per_epoch:
break
print("done")
model_checkpoint(CHECKPOINT_FOLDER, i_epoch)
model_checkpoint()


Discriminator gradient penalty only works if we remove batch norm (a standalone sketch of the penalty follows these notes).
The discriminator is terrible if we remove batch norm, allowing stark, pixelated adversarial examples.
The checkerboard pattern in generated images is caused by overlap in transposed convolutions (e.g. kernel size 4, stride 2). See this explanation. The fix used in conv5: upsample, then finish with a stride-1 convolution in the final generator layer.
Generator batch norm makes it less prone to catastrophic forgetting / discriminator exploitation. With gen-BN the initial generated images are pixelated rather than a flat gray; without it the images oscillate from white to black, overfitting wildly.
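For reference, a minimal standalone sketch of the gradient penalty in its standard WGAN-GP form (per-sample epsilon, detached fakes, explicit requires_grad_), assuming a critic dis that maps (n, 3, 128, 128) images to (n, 1) scores; this is the textbook formulation rather than exactly what the training loop above does:

def gradient_penalty(dis, x_real, x_fake, lmbda=10):
    n = x_real.shape[0]
    # Per-sample mixing coefficients
    eps = torch.rand(n, 1, 1, 1, device=x_real.device)
    x_mix = (eps * x_real + (1 - eps) * x_fake.detach()).requires_grad_(True)
    scores = dis(x_mix)
    # Gradient of the summed scores with respect to the mixed inputs
    grads = torch.autograd.grad(outputs=scores.sum(), inputs=x_mix, create_graph=True)[0]
    # Penalize deviation of each sample's gradient norm from 1
    return lmbda * ((grads.view(n, -1).norm(dim=1) - 1) ** 2).mean()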